{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "gJEMjPgeI-rw",
    "outputId": "7491c067-b1be-4505-b3f5-19ba4c00a593"
   },
   "outputs": [],
   "source": [
    "!pip install moshi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "CA4K5iDFJcqJ",
    "outputId": "b609843a-a193-4729-b099-5f8780532333"
   },
   "outputs": [],
   "source": [
    "!wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VA3Haix3IZ8Q"
   },
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "import time\n",
    "import sentencepiece\n",
    "import sphn\n",
    "import textwrap\n",
    "import torch\n",
    "\n",
    "from moshi.models import loaders, MimiModel, LMModel, LMGen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9AK5zBMTI9bw"
   },
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class InferenceState:\n",
    "    mimi: MimiModel\n",
    "    text_tokenizer: sentencepiece.SentencePieceProcessor\n",
    "    lm_gen: LMGen\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        mimi: MimiModel,\n",
    "        text_tokenizer: sentencepiece.SentencePieceProcessor,\n",
    "        lm: LMModel,\n",
    "        batch_size: int,\n",
    "        device: str | torch.device,\n",
    "    ):\n",
    "        self.mimi = mimi\n",
    "        self.text_tokenizer = text_tokenizer\n",
    "        self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)\n",
    "        self.device = device\n",
    "        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)\n",
    "        self.batch_size = batch_size\n",
    "        self.mimi.streaming_forever(batch_size)\n",
    "        self.lm_gen.streaming_forever(batch_size)\n",
    "\n",
    "    def run(self, in_pcms: torch.Tensor):\n",
    "        ntokens = 0\n",
    "        first_frame = True\n",
    "        chunks = [\n",
    "            c\n",
    "            for c in in_pcms.split(self.frame_size, dim=2)\n",
    "            if c.shape[-1] == self.frame_size\n",
    "        ]\n",
    "        start_time = time.time()\n",
    "        all_text = []\n",
    "        for chunk in chunks:\n",
    "            codes = self.mimi.encode(chunk)\n",
    "            if first_frame:\n",
    "                # Ensure that the first slice of codes is properly seen by the transformer\n",
    "                # as otherwise the first slice is replaced by the initial tokens.\n",
    "                tokens = self.lm_gen.step(codes)\n",
    "                first_frame = False\n",
    "            tokens = self.lm_gen.step(codes)\n",
    "            if tokens is None:\n",
    "                continue\n",
    "            assert tokens.shape[1] == 1\n",
    "            one_text = tokens[0, 0].cpu()\n",
    "            if one_text.item() not in [0, 3]:\n",
    "                text = self.text_tokenizer.id_to_piece(one_text.item())\n",
    "                text = text.replace(\"▁\", \" \")\n",
    "                all_text.append(text)\n",
    "            ntokens += 1\n",
    "        dt = time.time() - start_time\n",
    "        print(\n",
    "            f\"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step\"\n",
    "        )\n",
    "        return \"\".join(all_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 353,
     "referenced_widgets": [
      "0a5f6f887e2b4cd1990a0e9ec0153ed9",
      "f7893826fcba4bdc87539589d669249b",
      "8805afb12c484781be85082ff02dad13",
      "97679c0d9ab44bed9a3456f2fcb541fd",
      "d73c0321bed54a52b5e1da0a7788e32a",
      "d67be13a920d4fc89e5570b5b29fc1d2",
      "6b377c2d7bf945fb89e46c39d246a332",
      "b82ff365c78e41ad8094b46daf79449d",
      "477aa7fa82dc42d5bce6f1743c45d626",
      "cbd288510c474430beb66f346f382c45",
      "aafc347cdf28428ea6a7abe5b46b726f",
      "fca09acd5d0d45468c8b04bfb2de7646",
      "79e35214b51b4a9e9b3f7144b0b34f7b",
      "89e9a37f69904bd48b954d627bff6687",
      "57028789c78248a7b0ad4f031c9545c9",
      "1150fcb427994c2984d4d0f4e4745fe5",
      "e24b1fc52f294f849019c9b3befb613f",
      "8724878682cf4c3ca992667c45009398",
      "36a22c977d5242008871310133b7d2af",
      "5b3683cad5cb4877b43fadd003edf97f",
      "703f98272e4d469d8f27f5a465715dd8",
      "9dbe02ef5fac41cfaee3d02946e65c88",
      "37faa87ad03a4271992c21ce6a629e18",
      "570c547e48cd421b814b2c5e028e4c0b",
      "b173768580fc4c0a8e3abf272e4c363a",
      "e57d1620f0a9427b85d8b4885ef4e8e3",
      "5dd4474df70743498b616608182714dd",
      "cc907676a65f4ad1bf68a77b4a00e89b",
      "a34abc3b118e4305951a466919c28ff6",
      "a77ccfcdb90146c7a63b4b2d232bc494",
      "f7313e6e3a27475993cab3961d6ae363",
      "39b47fad9c554839868fe9e4bbf7def2",
      "14e9511ea0bd44c49f0cf3abf1a6d40e",
      "a4ea8e0c4cac4d5e88b7e3f527e4fe90",
      "571afc0f4b2840c9830d6b5a307ed1f9",
      "6ec593cab5b64f0ea638bb175b9daa5c",
      "77a52aed00ae408bb24524880e19ec8a",
      "0b2de4b29b4b44fe9d96361a40c793d0",
      "3c5b5fb1a5ac468a89c1058bd90cfb58",
      "e53e0a2a240e43cfa562c89b3d703dea",
      "35966343cf9249ef8bc028a0d5c5f97d",
      "e36a37e0d41c47ccb8bc6d56c19fb17c",
      "279ccf7de43847a1a6579c9182a46cc8",
      "41b5d6ab0b7d43c790a55f125c0e7494"
     ]
    },
    "id": "UsQJdAgkLp9n",
    "outputId": "9b7131c3-69c5-4323-8312-2ce7621d8869"
   },
   "outputs": [],
   "source": [
    "device = \"cuda\"\n",
    "# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en\n",
    "checkpoint_info = loaders.CheckpointInfo.from_hf_repo(\"kyutai/stt-1b-en_fr\")\n",
    "mimi = checkpoint_info.get_mimi(device=device)\n",
    "text_tokenizer = checkpoint_info.get_text_tokenizer()\n",
    "lm = checkpoint_info.get_moshi(device=device)\n",
    "in_pcms, _ = sphn.read(\"sample_fr_hibiki_crepes.mp3\", sample_rate=mimi.sample_rate)\n",
    "in_pcms = torch.from_numpy(in_pcms).to(device=device)\n",
    "\n",
    "stt_config = checkpoint_info.stt_config\n",
    "pad_left = int(stt_config.get(\"audio_silence_prefix_seconds\", 0.0) * 24000)\n",
    "pad_right = int((stt_config.get(\"audio_delay_seconds\", 0.0) + 1.0) * 24000)\n",
    "in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode=\"constant\")\n",
    "in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)\n",
    "\n",
    "state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)\n",
    "text = state.run(in_pcms)\n",
    "print(textwrap.fill(text, width=100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 75
    },
    "id": "CIAXs9oaPrtj",
    "outputId": "94cc208c-2454-4dd4-a64e-d79025144af5"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "\n",
    "Audio(\"sample_fr_hibiki_crepes.mp3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qkUZ6CBKOdTa"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "L4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
